# Copyright (c) 2017 NVIDIA Corporation

# to run against cuda:
# --gpu_ids 0 --path_to_train_data Netflix/N1W_TRAIN --path_to_eval_data Netflix/N1W_VALID --hidden_layers 512,512,1024 --non_linearity_type selu --batch_size 128 --logdir model_save --drop_prob 0.8 --optimizer momentum --lr 0.005 --weight_decay 0 --aug_step 1 --noise_prob 0 --num_epochs 1 --summary_frequency 1000 --forcecuda

# to run on cpu:
# --gpu_ids 0 --path_to_train_data Netflix/N1W_TRAIN --path_to_eval_data Netflix/N1W_VALID --hidden_layers 512,512,1024 --non_linearity_type selu --batch_size 128 --logdir model_save --drop_prob 0.8 --optimizer momentum --lr 0.005 --weight_decay 0 --aug_step 1 --noise_prob 0 --num_epochs 1 --summary_frequency 1000 --forcecpu


import argparse
import copy
import os
import time

# from .logger import Logger
from math import sqrt
from pathlib import Path

import numpy as np
import torch
import torch.autograd.profiler as profiler
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.optim.lr_scheduler import MultiStepLR

from .reco_encoder.data import input_layer
from .reco_encoder.model import model


def getTrainBenchmarkArgs():
    """Return a fixed argument object for benchmark runs.

    Mirrors the command line documented at the top of this file
    (momentum optimizer, lr 0.005, batch 128, one epoch, Netflix 1-week
    train/valid splits located next to this file).

    :return: ad-hoc ``Args`` object with the same attributes argparse
        would produce, plus a ``device`` default consumed by
        ``processTrainArgState`` (``TrainInit`` overwrites it when a
        device string is passed in).
    """

    class Args:
        pass

    args = Args()
    args.lr = 0.005
    args.weight_decay = 0
    args.drop_prob = 0.8
    args.noise_prob = 0
    args.batch_size = 128
    args.summary_frequency = 1000
    args.aug_step = 1
    args.constrained = False
    args.skip_last_layer_nl = False
    args.num_epochs = 1
    args.save_every = 3
    args.optimizer = "momentum"
    args.hidden_layers = "512,512,1024"
    args.gpu_ids = "0"
    args.path_to_train_data = os.path.dirname(__file__) + "/Netflix/N1W_TRAIN"
    args.path_to_eval_data = os.path.dirname(__file__) + "/Netflix/N1W_VALID"
    args.non_linearity_type = "selu"
    args.logdir = "model_save"
    args.nooutput = True
    args.silent = True
    args.forcecuda = False
    args.forcecpu = False
    args.profile = False
    # Default device; previously unset, which made processTrainArgState
    # fail with AttributeError if this object was used directly.
    args.device = "cpu"

    return args


def getTrainCommandLineArgs():
    """Parse training arguments from ``sys.argv``.

    :return: argparse namespace; also carries a derived ``device``
        attribute ("cuda" when --forcecuda is given, else "cpu") so that
        ``processTrainArgState`` can read ``args.device`` on this path
        too (previously it was only set by the benchmark path and the
        CLI path crashed with AttributeError).
    """

    parser = argparse.ArgumentParser(description="RecoEncoder")
    parser.add_argument(
        "--lr", type=float, default=0.00001, metavar="N", help="learning rate"
    )
    parser.add_argument(
        "--weight_decay", type=float, default=0.0, metavar="N", help="L2 weight decay"
    )
    parser.add_argument(
        "--drop_prob",
        type=float,
        default=0.0,
        metavar="N",
        help="dropout drop probability",
    )
    parser.add_argument(
        "--noise_prob", type=float, default=0.0, metavar="N", help="noise probability"
    )
    parser.add_argument(
        "--batch_size", type=int, default=64, metavar="N", help="global batch size"
    )
    parser.add_argument(
        "--summary_frequency",
        type=int,
        default=100,
        metavar="N",
        help="how often to save summaries",
    )
    parser.add_argument(
        "--aug_step",
        type=int,
        default=-1,
        metavar="N",
        help="do data augmentation every X step",
    )
    parser.add_argument(
        "--constrained", action="store_true", help="constrained autoencoder"
    )
    parser.add_argument(
        "--skip_last_layer_nl",
        action="store_true",
        help="if present, decoder's last layer will not apply non-linearity function",
    )
    parser.add_argument(
        "--num_epochs",
        type=int,
        default=50,
        metavar="N",
        help="maximum number of epochs",
    )
    parser.add_argument(
        "--save_every",
        type=int,
        default=3,
        metavar="N",
        help="save every N number of epochs",
    )
    parser.add_argument(
        "--optimizer",
        type=str,
        default="momentum",
        metavar="N",
        help="optimizer kind: adam, momentum, adagrad or rmsprop",
    )
    parser.add_argument(
        "--hidden_layers",
        type=str,
        default="1024,512,512,128",
        metavar="N",
        help="hidden layer sizes, comma-separated",
    )
    parser.add_argument(
        "--gpu_ids",
        type=str,
        default="0",
        metavar="N",
        help="comma-separated gpu ids to use for data parallel training",
    )
    parser.add_argument(
        "--path_to_train_data",
        type=str,
        default="",
        metavar="N",
        help="Path to training data",
    )
    parser.add_argument(
        "--path_to_eval_data",
        type=str,
        default="",
        metavar="N",
        help="Path to evaluation data",
    )
    parser.add_argument(
        "--non_linearity_type",
        type=str,
        default="selu",
        metavar="N",
        help="type of the non-linearity used in activations",
    )
    parser.add_argument(
        "--logdir",
        type=str,
        default="logs",
        metavar="N",
        help="where to save model and write logs",
    )
    parser.add_argument(
        "--nooutput", action="store_true", help="disable writing output to file"
    )
    parser.add_argument("--silent", action="store_true", help="disable all messages")
    parser.add_argument("--forcecuda", action="store_true", help="force cuda use")
    parser.add_argument("--forcecpu", action="store_true", help="force cpu use")
    parser.add_argument(
        "--profile", action="store_true", help="enable profiler and stat print"
    )

    args = parser.parse_args()

    # Derive the device string consumed downstream by processTrainArgState
    # and DoTrain; the benchmark path sets it in TrainInit, but the CLI
    # path previously left it unset.
    args.device = "cuda" if args.forcecuda else "cpu"

    return args


def processTrainArgState(args):
    """Validate device flags and attach derived GPU-usage state to args.

    Sets ``args.use_cuda`` (CUDA available and not forced to CPU) and
    ``args.use_xpu`` (``args.device == "xpu"``), printing status unless
    ``args.silent``.

    :param args: namespace with at least silent/forcecpu/forcecuda;
        ``device`` is optional (only the benchmark path guarantees it)
    :return: the same args object, mutated in place
    """

    if not args.silent:
        print(args)

    # Mutually exclusive flags: refuse to guess which one wins.
    if args.forcecpu and args.forcecuda:
        print("Error, force cpu and cuda cannot both be set")
        quit()

    args.use_cuda = torch.cuda.is_available()  # global flag
    # `device` may be absent (command-line parsed args); treat that as
    # "not xpu" instead of raising AttributeError.
    args.use_xpu = getattr(args, "device", None) == "xpu"
    if not args.silent:
        if args.use_cuda or args.use_xpu:
            print("GPU is available.")
        else:
            print("GPU is not available.")

    if args.use_cuda and args.forcecpu:
        args.use_cuda = False

    if not args.silent:
        if args.use_cuda or args.use_xpu:
            print("Running On GPU")
        else:
            print("Running On CPU")

    return args


def log_var_and_grad_summaries(
    logger, layers, global_step, prefix, log_histograms=False
):
    """
    Log the Frobenius norm (and, optionally, a histogram) of every layer's
    weights and gradients. Data is moved from GPU to CPU automatically.
    :param logger: TB logger
    :param layers: iterable of parameters to summarize
    :param global_step: global step for TB
    :param prefix: tag name prefix
    :param log_histograms: (default: False) also emit histogram summaries
    :return:
    """
    for idx, param in enumerate(layers):
        # Weight values
        weights = param.data.cpu().numpy()
        logger.scalar_summary(
            "Variables/FrobNorm/{}_{}".format(prefix, idx),
            np.linalg.norm(weights),
            global_step,
        )
        if log_histograms:
            logger.histo_summary(
                tag="Variables/{}_{}".format(prefix, idx),
                values=param.data.cpu().numpy(),
                step=global_step,
            )

        # Gradient values
        grads = param.grad.data.cpu().numpy()
        logger.scalar_summary(
            "Gradients/FrobNorm/{}_{}".format(prefix, idx),
            np.linalg.norm(grads),
            global_step,
        )
        if log_histograms:
            logger.histo_summary(
                tag="Gradients/{}_{}".format(prefix, idx),
                values=param.grad.data.cpu().numpy(),
                step=global_step,
            )


def DoTrainEval(encoder, evaluation_data_layer, device):
    """Run one evaluation epoch and return the RMSE over all rated entries.

    :param encoder: autoencoder model (switched to eval mode here)
    :param evaluation_data_layer: provider exposing iterate_one_epoch_eval()
    :param device: device the sparse batches are moved to before densifying
    :return: sqrt(total MSE loss / total number of ratings)
    """
    encoder.eval()
    rating_count = 0.0
    loss_sum = 0.0
    for eval_batch, src_batch in evaluation_data_layer.iterate_one_epoch_eval():
        dense_inputs = Variable(src_batch.to(device).to_dense())
        dense_targets = Variable(eval_batch.to(device).to_dense())
        predictions = encoder(dense_inputs)
        batch_loss, batch_ratings = model.MSEloss(predictions, dense_targets)
        loss_sum += batch_loss.item()
        rating_count += batch_ratings.item()
    return sqrt(loss_sum / rating_count)


class DeepRecommenderTrainBenchmark:
    """Training benchmark harness for the DeepRecommender autoencoder.

    By default runs in "toy" mode: a single random batch is created at
    init time and pushed through the model once per epoch. When
    ``self.toytest`` is False, the Netflix data layers are loaded and a
    full training loop (with evaluation, checkpointing and ONNX export)
    is run instead.
    """

    def __init__(
        self, device="cpu", jit=False, batch_size=256, processCommandLine=False
    ):
        self.TrainInit(device, jit, batch_size, processCommandLine)

    def TrainInit(
        self, device="cpu", jit=False, batch_size=256, processCommandLine=False
    ):
        """Build args, model, optimizer and (when not in toy mode) data layers.

        :param device: "cpu", "cuda" or "xpu"; any other value aborts init
        :param jit: accepted for interface compatibility, currently unused
        :param batch_size: rows in the toy-mode random input batch
        :param processCommandLine: parse sys.argv instead of using the
            fixed benchmark arguments
        """

        # Force test to run in toy mode. Single call of fake data to model.
        self.toytest = True
        self.toybatch = batch_size

        # number of movies in netflix training set.
        self.toyvocab = 197951

        self.toyinputs = torch.randn(self.toybatch, self.toyvocab)

        if processCommandLine:
            self.args = getTrainCommandLineArgs()
        else:
            self.args = getTrainBenchmarkArgs()

            if device == "cpu":
                forcecuda = False
            elif device == "cuda":
                forcecuda = True
            elif device == "xpu":
                forcecuda = False
            else:
                # unknown device string, quit init
                return

            self.args.forcecuda = forcecuda
            self.args.forcecpu = not forcecuda and device == "cpu"
            self.args.device = device

        self.args = processTrainArgState(self.args)

        if self.toytest == False:
            # NOTE(review): the Logger import at the top of the file is
            # commented out, so this branch would raise NameError if
            # toytest were ever disabled -- restore the import first.
            self.logger = Logger(self.args.logdir)
        self.params = dict()
        self.params["batch_size"] = self.args.batch_size
        self.params["data_dir"] = self.args.path_to_train_data
        self.params["major"] = "users"
        self.params["itemIdInd"] = 1
        self.params["userIdInd"] = 0

        if self.toytest == False:
            if not self.args.silent:
                print("Loading training data")

            self.data_layer = input_layer.UserItemRecDataProvider(params=self.params)
            if not self.args.silent:
                print("Data loaded")
                print("Total items found: {}".format(len(self.data_layer.data.keys())))
                print("Vector dim: {}".format(self.data_layer.vector_dim))

                print("Loading eval data")

        self.eval_params = copy.deepcopy(self.params)

        # must set eval batch size to 1 to make sure no examples are missed
        if self.toytest:
            # Toy mode: model sized for the fake vocabulary, no data layers.
            self.rencoder = model.AutoEncoder(
                layer_sizes=[self.toyvocab]
                + [int(l) for l in self.args.hidden_layers.split(",")],
                nl_type=self.args.non_linearity_type,
                is_constrained=self.args.constrained,
                dp_drop_prob=self.args.drop_prob,
                last_layer_activations=not self.args.skip_last_layer_nl,
            )
        else:
            self.eval_params["data_dir"] = self.args.path_to_eval_data
            self.eval_data_layer = input_layer.UserItemRecDataProvider(
                params=self.eval_params,
                user_id_map=self.data_layer.userIdMap,  # the mappings are provided
                item_id_map=self.data_layer.itemIdMap,
            )
            self.eval_data_layer.src_data = self.data_layer.data
            self.rencoder = model.AutoEncoder(
                layer_sizes=[self.data_layer.vector_dim]
                + [int(l) for l in self.args.hidden_layers.split(",")],
                nl_type=self.args.non_linearity_type,
                is_constrained=self.args.constrained,
                dp_drop_prob=self.args.drop_prob,
                last_layer_activations=not self.args.skip_last_layer_nl,
            )

            os.makedirs(self.args.logdir, exist_ok=True)
            self.model_checkpoint = self.args.logdir + "/model"
            self.path_to_model = Path(self.model_checkpoint)
            if self.path_to_model.is_file():
                print("Loading model from: {}".format(self.model_checkpoint))
                self.rencoder.load_state_dict(torch.load(self.model_checkpoint))

        if not self.args.silent:
            print("######################################################")
            print("######################################################")
            print("############# AutoEncoder Model: #####################")
            print(self.rencoder)
            print("######################################################")
            print("######################################################")

        if self.args.use_cuda:
            gpu_ids = [int(g) for g in self.args.gpu_ids.split(",")]
            if not self.args.silent:
                print("Using GPUs: {}".format(gpu_ids))

            # Only wrap in DataParallel when more than one GPU is requested.
            if len(gpu_ids) > 1:
                self.rencoder = nn.DataParallel(self.rencoder, device_ids=gpu_ids)

        self.toyinputs = self.toyinputs.to(device)
        self.rencoder = self.rencoder.to(device)

        if self.args.optimizer == "adam":
            self.optimizer = optim.Adam(
                self.rencoder.parameters(),
                lr=self.args.lr,
                weight_decay=self.args.weight_decay,
            )
        elif self.args.optimizer == "adagrad":
            self.optimizer = optim.Adagrad(
                self.rencoder.parameters(),
                lr=self.args.lr,
                weight_decay=self.args.weight_decay,
            )
        elif self.args.optimizer == "momentum":
            self.optimizer = optim.SGD(
                self.rencoder.parameters(),
                lr=self.args.lr,
                momentum=0.9,
                weight_decay=self.args.weight_decay,
            )
            self.scheduler = MultiStepLR(
                self.optimizer, milestones=[24, 36, 48, 66, 72], gamma=0.5
            )
        # Fixed: this previously read the undefined bare name `args`,
        # raising NameError whenever optimizer == "rmsprop".
        elif self.args.optimizer == "rmsprop":
            self.optimizer = optim.RMSprop(
                self.rencoder.parameters(),
                lr=self.args.lr,
                momentum=0.9,
                weight_decay=self.args.weight_decay,
            )
        else:
            raise ValueError("Unknown optimizer kind")

        # Running loss accumulators used by DoTrain()/train().
        self.t_loss = 0.0
        self.t_loss_denom = 0.0
        self.denom = 0.0
        self.total_epoch_loss = 0.0
        self.global_step = 0

        # Optional denoising dropout applied during data augmentation.
        if self.args.noise_prob > 0.0:
            self.dp = nn.Dropout(p=self.args.noise_prob)

    def get_optimizer(self):
        """Return the optimizer built by TrainInit."""
        return self.optimizer

    def set_optimizer(self, optimizer):
        """Replace the optimizer used for subsequent training steps."""
        self.optimizer = optimizer

    def DoTrain(self):
        """Run one full training epoch over the real (non-toy) data layer.

        Accumulates running RMSE state in self.t_loss/self.t_loss_denom
        and epoch totals in self.total_epoch_loss/self.denom; expects
        self.epoch to be set by train().
        """

        self.rencoder.train()
        # if self.args.optimizer == "momentum":
        #  self.scheduler.step()

        for i, mb in enumerate(self.data_layer.iterate_one_epoch()):

            inputs = Variable(mb.to(self.args.device).to_dense())

            self.optimizer.zero_grad()

            outputs = self.rencoder(inputs)

            loss, num_ratings = model.MSEloss(outputs, inputs)
            loss = loss / num_ratings
            loss.backward()
            self.optimizer.step()
            self.global_step += 1
            self.t_loss += loss.item()
            self.t_loss_denom += 1

            if not self.args.nooutput:
                if i % self.args.summary_frequency == 0:
                    print(
                        "[%d, %5d] RMSE: %.7f"
                        % (self.epoch, i, sqrt(self.t_loss / self.t_loss_denom))
                    )
                    self.logger.scalar_summary(
                        "Training_RMSE",
                        sqrt(self.t_loss / self.t_loss_denom),
                        self.global_step,
                    )
                    self.t_loss = 0
                    self.t_loss_denom = 0.0
                    log_var_and_grad_summaries(
                        self.logger,
                        self.rencoder.encode_w,
                        self.global_step,
                        "Encode_W",
                    )
                    log_var_and_grad_summaries(
                        self.logger,
                        self.rencoder.encode_b,
                        self.global_step,
                        "Encode_b",
                    )
                    if not self.rencoder.is_constrained:
                        log_var_and_grad_summaries(
                            self.logger,
                            self.rencoder.decode_w,
                            self.global_step,
                            "Decode_W",
                        )
                    log_var_and_grad_summaries(
                        self.logger,
                        self.rencoder.decode_b,
                        self.global_step,
                        "Decode_b",
                    )

            self.total_epoch_loss += loss.item()
            self.denom += 1

            # if args.aug_step > 0 and i % args.aug_step == 0 and i > 0:
            if self.args.aug_step > 0:
                # Magic data augmentation trick happen here: re-feed the
                # model's own (dense) output as the next input.
                for t in range(self.args.aug_step):
                    inputs = Variable(outputs.data)
                    if self.args.noise_prob > 0.0:
                        # Fixed: was the undefined bare name `dp`, which
                        # raised NameError whenever noise_prob > 0.
                        inputs = self.dp(inputs)
                    self.optimizer.zero_grad()
                    outputs = self.rencoder(inputs)
                    loss, num_ratings = model.MSEloss(outputs, inputs)
                    loss = loss / num_ratings
                    loss.backward()
                    self.optimizer.step()

    def train(self, niter=1):
        """Train for ``niter`` epochs.

        In toy mode each epoch is a single forward/backward/step on the
        fixed random batch. Otherwise DoTrain() runs a full epoch, with
        periodic evaluation, checkpointing and a final ONNX export.

        :param niter: number of epochs to run
        """
        for self.epoch in range(niter):

            if self.toytest:
                self.rencoder.train()
                self.optimizer.zero_grad()
                outputs = self.rencoder(self.toyinputs)
                loss, num_ratings = model.MSEloss(outputs, self.toyinputs)
                loss = loss / num_ratings
                loss.backward()
                self.optimizer.step()
                continue

            if not self.args.silent:
                print("Doing epoch {} of {}".format(self.epoch, niter))
                print("Timing Start")
                e_start_time = time.time()

            self.DoTrain()

            if not self.args.silent:
                e_end_time = time.time()
                print("Timing End")

                # Profiler stats are printed by TimedTrainingRun after the
                # profiler context exits; the old code referenced an
                # undefined `prof` here and crashed when --profile was set.

                print(
                    "Total epoch {} finished in {} seconds with TRAINING RMSE loss: {}".format(
                        self.epoch,
                        e_end_time - e_start_time,
                        sqrt(self.total_epoch_loss / self.denom),
                    )
                )

            if not self.args.silent:
                self.logger.scalar_summary(
                    "Training_RMSE_per_epoch",
                    sqrt(self.total_epoch_loss / self.denom),
                    self.epoch,
                )
                self.logger.scalar_summary(
                    "Epoch_time", e_end_time - e_start_time, self.epoch
                )
                if (
                    self.epoch % self.args.save_every == 0
                    or self.epoch == self.args.num_epochs - 1
                ):
                    eval_loss = DoTrainEval(
                        self.rencoder, self.eval_data_layer, self.args.device
                    )
                    print("Epoch {} EVALUATION LOSS: {}".format(self.epoch, eval_loss))

                    self.logger.scalar_summary("EVALUATION_RMSE", eval_loss, self.epoch)
                    print(
                        "Saving model to {}".format(
                            self.model_checkpoint + ".epoch_" + str(self.epoch)
                        )
                    )
                    torch.save(
                        self.rencoder.state_dict(),
                        self.model_checkpoint + ".epoch_" + str(self.epoch),
                    )

        if not self.args.nooutput:
            # NOTE(review): this branch relies on self.model_checkpoint and
            # self.data_layer, which only exist when toytest is False.
            print("Saving model to {}".format(self.model_checkpoint + ".last"))
            torch.save(self.rencoder.state_dict(), self.model_checkpoint + ".last")

            # save to onnx
            dummy_input = Variable(
                torch.randn(self.params["batch_size"], self.data_layer.vector_dim).type(
                    torch.float
                )
            )
            # Fixed: was `dummy_input.to(device)` -- `device` is not
            # defined in this scope; use the configured device instead.
            torch.onnx.export(
                self.rencoder.float(),
                dummy_input.to(self.args.device),
                self.model_checkpoint + ".onnx",
                verbose=True,
            )
            print("ONNX model saved to {}!".format(self.model_checkpoint + ".onnx"))

    def TimedTrainingRun(self):
        """Train for args.num_epochs, optionally under the autograd profiler."""
        if self.args.profile:
            # NOTE(review): `use_xpu` is not a kwarg of the stock PyTorch
            # profiler; this presumably targets an Intel extension build.
            with profiler.profile(
                record_shapes=True,
                use_cuda=self.args.use_cuda,
                use_xpu=self.args.use_xpu,
            ) as prof:
                with profiler.record_function("training_epoch"):
                    self.train(self.args.num_epochs)
            # Report stats once the profiler context has closed (moved
            # here from train(), where `prof` was not in scope).
            print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
            prof.export_chrome_trace("trace.json")
        else:
            self.train(self.args.num_epochs)


def main():
    """Run the training benchmark once on CUDA, then once on CPU."""

    gpuTrain = DeepRecommenderTrainBenchmark(device="cuda")
    gpuTrain.TimedTrainingRun()

    # Fixed: previously instantiated the nonexistent class
    # `DeepRecommenderBenchmark`, raising NameError on the CPU pass.
    cpuTrain = DeepRecommenderTrainBenchmark(device="cpu")
    cpuTrain.TimedTrainingRun()


if __name__ == "__main__":

    main()
